import numpy as np


# NOTE: a stray argument-less ``np.sort()`` call used to live here; it raised
# ``TypeError`` as soon as the module was imported, so it has been removed.
#
# Environment during development:
#         numpy : 1.18.5

class KMeans:
    """
    K-means clustering (Lloyd's algorithm) implemented with NumPy.

    --------------------------------------------------------
    example for KMeans class

    >>> from sklearn.datasets import load_iris  # doctest: +SKIP
    >>> import numpy as np
    >>> iris = load_iris()  # doctest: +SKIP
    >>> K = KMeans(3, init='random')
    >>> _, t = K.fit(np.array(iris.data))  # doctest: +SKIP
    >>> print(t[:, 1])  # doctest: +SKIP
    >>> print(iris.target)  # doctest: +SKIP
    """

    def __init__(self, n_clusters=8, init='random', max_iter=300, tol=1e-4):
        """
        n_clusters: The number of clusters
        init: Centroid initialisation strategy, 'random' or 'kmeans++'
        max_iter: The maximum number of iterations
        tol: Iteration stops early once the Frobenius norm of the centroid
             shift between two consecutive iterations is <= tol
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.init = init
        self.tol = tol
        self.__init_method = {"random": self.__get_center, "kmeans++": self.__kpp}

    def fit(self, x):
        """
        Cluster the rows of ``x``.

        x: array-like of shape (n_samples, n_features)

        Returns ``(centroids, cluster_assment)`` where ``centroids`` has shape
        (n_clusters, n_features) and ``cluster_assment`` has shape
        (n_samples, 2): column 0 holds the Euclidean distance of each sample
        to its assigned centroid and column 1 the index of that centroid.

        Raises ValueError if ``x`` is not 2-D or if ``n_clusters`` is not in
        ``1..n_samples``; raises KeyError for an unknown ``init`` name.
        """
        # Work in floating point: with integer input the centroid means would
        # otherwise be silently truncated when written back into the array.
        x = np.asarray(x, dtype=float)
        if x.ndim != 2:
            raise ValueError("x must be a 2-D array of shape (n_samples, n_features)")
        m = x.shape[0]
        if not 1 <= self.n_clusters <= m:
            raise ValueError(
                "n_clusters must be between 1 and the number of samples "
                "(got n_clusters=%r, n_samples=%r)" % (self.n_clusters, m)
            )

        centroids = self.__init_method[self.init](x)
        cluster_assment = np.zeros((m, 2))
        rows = np.arange(m)
        for _ in range(self.max_iter):
            # (m, n_clusters) matrix of sample-to-centroid distances.
            distances = self.__distance(x[:, np.newaxis, :], centroids[np.newaxis, :, :])
            # argmin keeps the first minimum, i.e. the lowest centroid index on ties.
            labels = np.argmin(distances, axis=1)
            cluster_assment[:, 0] = distances[rows, labels]
            cluster_assment[:, 1] = labels

            previous = centroids.copy()
            for c in range(self.n_clusters):
                members = x[labels == c]
                # An empty cluster keeps its previous centroid.
                if members.shape[0] > 0:
                    centroids[c, :] = members.mean(axis=0)

            # Converged: another pass could not move the centroids noticeably.
            if np.linalg.norm(centroids - previous) <= self.tol:
                break
        return centroids, cluster_assment

    def __distance(self, v1, v2):
        """
        Euclidean distance along the last axis; ``v1`` and ``v2`` broadcast,
        so two 1-D vectors give a scalar.
        """
        return np.sqrt(np.sum(np.power(v1 - v2, 2), axis=-1))

    def __get_center(self, x):
        """
        Pick ``n_clusters`` distinct rows of ``x`` uniformly at random and
        return them as a new float array.
        """
        m, _ = x.shape
        # replace=False: sampling with replacement could return the same row
        # twice and start the algorithm with duplicate centroids.
        chosen = np.random.choice(m, self.n_clusters, replace=False)
        return np.array(x[chosen], dtype=float)

    def __kpp(self, x):
        """
        Use kmeans++ algorithm to initialize cluster centers

        The first center is drawn uniformly; every further center is drawn
        with probability proportional to the squared distance of each sample
        to its nearest already-chosen center.
        """
        m, n_features = x.shape
        centers = np.empty((self.n_clusters, n_features), dtype=float)
        centers[0] = x[np.random.randint(m)]
        # Squared distance of every sample to its closest chosen center so far.
        closest_sq = self.__distance(x, centers[0]) ** 2
        for j in range(1, self.n_clusters):
            total = closest_sq.sum()
            if total > 0:
                pick = np.random.choice(m, p=closest_sq / total)
            else:
                # Every sample coincides with a chosen center; the weights
                # would be 0/0, so fall back to a uniform draw.
                pick = np.random.randint(m)
            centers[j] = x[pick]
            closest_sq = np.minimum(closest_sq, self.__distance(x, centers[j]) ** 2)
        return centers
